{
 "cells": [
  {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
     "# 由于还不支持numpy，目前notebook中的 \"Part B: Encrypted Aggregation\" 无法正常使用。\n",
     "\n",
     "如需查看状态，请访问 https://github.com/OpenMined/PySyft/issues/2771.\n",
     "\n",
     "只要B部分恢复正常，这段话就会被移除。"
    ]
   },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Part 10: 安全聚合的联邦学习\n",
    "\n",
    "在最后几节中，我们通过构建几个简单的程序来学习加密计算。 在本节中，我们将返回到[第4部分的联合学习演示](https://github.com/OpenMined/PySyft/blob/dev/examples/tutorials/Part%204%20-%20Federated% 20Learning％20via％20Trusted％20Aggregator.ipynb)，我们在那里有一个“受信任的聚合器”，负责平均多个工作机的模型更新。\n",
    "\n",
    "现在，我们将使用新的工具进行加密计算，以删除此受信任的聚合器，因为它不理想，因为它假定我们可以找到足够值得信任的人来访问此敏感信息。这并非总是如此。\n",
    "\n",
    "因此，在这个笔记本中，我们将展示如何使用SMPC来执行安全聚合，这样我们就不需要“受信任的聚合器”。\n",
    "\n",
    "作者:\n",
    "- Theo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel)\n",
    "- Andrew Trask - Twitter: [@iamtrask](https://twitter.com/iamtrask)\n",
    "\n",
    "中文版译者：\n",
    "- Hou Wei - github：[@dljgs1](https://github.com/dljgs1)\n",
    "- Haofan Wang - github：[@haofanwang](https://github.com/haofanwang)\n"
   ]
  },
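  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before diving in, it may help to see the core trick behind secure aggregation in isolation. The next cell is a minimal, PySyft-free sketch of additive secret sharing (the field size `Q`, the helper names `share`/`reconstruct`, and the two-worker values are illustrative assumptions, not part of PySyft's API): each worker splits its update into random shares that individually reveal nothing, the shares are summed share-wise, and only the final aggregate is ever reconstructed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "Q = 2**31 - 1  # size of the finite field; an illustrative choice\n",
    "\n",
    "def share(secret, n_shares=2):\n",
    "    \"\"\"Split an integer into n random shares that sum to it modulo Q.\"\"\"\n",
    "    shares = [random.randrange(Q) for _ in range(n_shares - 1)]\n",
    "    shares.append((secret - sum(shares)) % Q)\n",
    "    return shares\n",
    "\n",
    "def reconstruct(shares):\n",
    "    \"\"\"Recover the secret by summing all shares modulo Q.\"\"\"\n",
    "    return sum(shares) % Q\n",
    "\n",
    "# two workers' integer-encoded model updates (toy values)\n",
    "bob_update, alice_update = 42, 58\n",
    "\n",
    "# each worker secret-shares its update; a single share reveals nothing\n",
    "bob_shares = share(bob_update)\n",
    "alice_shares = share(alice_update)\n",
    "\n",
    "# shares can be added share-wise without decrypting anything\n",
    "sum_shares = [(b + a) % Q for b, a in zip(bob_shares, alice_shares)]\n",
    "\n",
    "# only the aggregate is ever reconstructed: 42 + 58 = 100, average = 50\n",
    "print(reconstruct(sum_shares) / 2)"
   ]
  },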
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 1: 普通的联邦学习\n",
    "\n",
    "首先，这是一些代码，用于在Boston Housing Dataset上执行经典的联邦学习。这部分代码分为几个部分。\n",
    "\n",
    "### 配置"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "class Parser:\n",
    "    \"\"\"Parameters for training\"\"\"\n",
    "    def __init__(self):\n",
    "        self.epochs = 10\n",
    "        self.lr = 0.001\n",
    "        self.test_batch_size = 8\n",
    "        self.batch_size = 8\n",
    "        self.log_interval = 10\n",
    "        self.seed = 1\n",
    "    \n",
    "args = Parser()\n",
    "\n",
    "torch.manual_seed(args.seed)\n",
    "kwargs = {}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('../data/BostonHousing/boston_housing.pickle','rb') as f:\n",
    "    ((X, y), (X_test, y_test)) = pickle.load(f)\n",
    "\n",
    "X = torch.from_numpy(X).float()\n",
    "y = torch.from_numpy(y).float()\n",
    "X_test = torch.from_numpy(X_test).float()\n",
    "y_test = torch.from_numpy(y_test).float()\n",
    "# preprocessing\n",
    "mean = X.mean(0, keepdim=True)\n",
    "dev = X.std(0, keepdim=True)\n",
    "mean[:, 3] = 0. # the feature at column 3 is binary,\n",
    "dev[:, 3] = 1.  # so we don't standardize it\n",
    "X = (X - mean) / dev\n",
    "X_test = (X_test - mean) / dev\n",
    "train = TensorDataset(X, y)\n",
    "test = TensorDataset(X_test, y_test)\n",
    "train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)\n",
    "test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True, **kwargs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 神经网络结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.fc1 = nn.Linear(13, 32)\n",
    "        self.fc2 = nn.Linear(32, 24)\n",
    "        self.fc3 = nn.Linear(24, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 13)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "model = Net()\n",
    "optimizer = optim.SGD(model.parameters(), lr=args.lr)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 挂钩PyTorch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft as sy\n",
    "\n",
    "hook = sy.TorchHook(torch)\n",
    "bob = sy.VirtualWorker(hook, id=\"bob\")\n",
    "alice = sy.VirtualWorker(hook, id=\"alice\")\n",
    "james = sy.VirtualWorker(hook, id=\"james\")\n",
    "\n",
    "compute_nodes = [bob, alice]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**将数据发送给工作机** \n",
    "通常工作机已经拥有了数据，这只是出于演示目的，我们选择手动发送"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_distributed_dataset = []\n",
    "\n",
    "for batch_idx, (data,target) in enumerate(train_loader):\n",
    "    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    train_distributed_dataset.append((data, target))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (data,target) in enumerate(train_distributed_dataset):\n",
    "        worker = data.location\n",
    "        model.send(worker)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        # update the model\n",
    "        pred = model(data)\n",
    "        loss = F.mse_loss(pred.view(-1), target)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        model.get()\n",
    "            \n",
    "        if batch_idx % args.log_interval == 0:\n",
    "            loss = loss.get()\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                epoch, batch_idx * data.shape[0], len(train_loader),\n",
    "                       100. * batch_idx / len(train_loader), loss.item()))\n",
    "        \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    for data, target in test_loader:\n",
    "        output = model(data)\n",
    "        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # sum up batch loss\n",
    "        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
    "        \n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('\\nTest set: Average loss: {:.4f}\\n'.format(test_loss))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "\n",
    "for epoch in range(1, args.epochs + 1):\n",
    "    train(epoch)\n",
    "\n",
    "    \n",
    "total_time = time.time() - t\n",
    "print('Total', round(total_time, 2), 's')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 计算性能"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 2: 添加加密聚合\n",
    "\n",
    "现在，我们将略微修改此示例，以使用加密来聚合梯度。 不同的主要部分实际上是`train（）`函数中的1或2行代码，我们将指出。目前，让我们重新处理数据并初始化bob和alice的模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "remote_dataset = (list(),list())\n",
    "\n",
    "train_distributed_dataset = []\n",
    "\n",
    "for batch_idx, (data,target) in enumerate(train_loader):\n",
    "    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])\n",
    "    remote_dataset[batch_idx % len(compute_nodes)].append((data, target))\n",
    "\n",
    "def update(data, target, model, optimizer):\n",
    "    model.send(data.location)\n",
    "    optimizer.zero_grad()\n",
    "    pred = model(data)\n",
    "    loss = F.mse_loss(pred.view(-1), target)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    return model\n",
    "\n",
    "bobs_model = Net()\n",
    "alices_model = Net()\n",
    "\n",
    "bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)\n",
    "alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)\n",
    "\n",
    "models = [bobs_model, alices_model]\n",
    "params = [list(bobs_model.parameters()), list(alices_model.parameters())]\n",
    "optimizers = [bobs_optimizer, alices_optimizer]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 建立我们的训练逻辑\n",
    "\n",
    "唯一的**真正的**差异是此训练方法的内部。让我们逐步讲解\n",
    "\n",
    "### A: 训练:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 选择训练哪个batch\n",
    "data_index = 0\n",
    "# 更新远程模型\n",
    "# 我们可以在进行此操作之前对其进行多次迭代，但是在这里每个工人只进行一次迭代\n",
    "for remote_index in range(len(compute_nodes)):\n",
    "    data, target = remote_dataset[remote_index][data_index]\n",
    "    models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### B: 加密聚合"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 创建一个列表，我们将在其中存储我们的加密模型平均值\n",
    "new_params = list()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 遍历每个参数\n",
    "for param_i in range(len(params[0])):\n",
    "\n",
    "    # 对每个工作机\n",
    "    spdz_params = list()\n",
    "    for remote_index in range(len(compute_nodes)):\n",
    "        \n",
    "        # 从每个工作机中选择相同的参数并复制\n",
    "        copy_of_parameter = params[remote_index][param_i].copy()\n",
    "        \n",
    "        \n",
    "        # 由于SMPC只能使用整数（不能使用浮点数）\n",
    "        # 因此我们需要使用Integers存储十进制信息。\n",
    "        # 换句话说，我们需要使用“固定精度”编码。\n",
    "        fixed_precision_param = copy_of_parameter.fix_precision()\n",
    "        \n",
    "        # 现在我们在远程计算机上对其进行加密。\n",
    "        # 注意，fixed_precision_param“已经”是一个指针。\n",
    "        # 因此，当我们调用share时，它实际上是对指向的数据进行加密。\n",
    "        # 而它会返回一个指向MPC秘密共享对象的指针，也就是我们需要的共享分片。\n",
    "        \n",
    "        encrypted_param = fixed_precision_param.share(bob, alice, crypto_provider=james)\n",
    "        \n",
    "        # 现在我们获取指向MPC共享值的指针\n",
    "        param = encrypted_param.get()\n",
    "        \n",
    "        # 保存参数，以便我们可以使用工作机的相同参数取平均值\n",
    "        spdz_params.append(param)\n",
    "\n",
    "    # 来自多个工作人员的平均参数，将它们提取到本地计算机上\n",
    "    # 以固定精度解密和解码，再返回一个浮点数\n",
    "    new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2\n",
    "    \n",
    "    # 保存新的平均参数\n",
    "    new_params.append(new_param)"
   ]
  },
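  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the `fix_precision()` step above more concrete, here is a tiny, PySyft-free sketch of fixed precision encoding (the base, the number of digits, and the helper names `encode`/`decode` are illustrative assumptions, not PySyft's exact internals): floats are scaled into integers so that integer-only protocols like SMPC can operate on them, then scaled back down when decoded."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BASE = 10\n",
    "PRECISION = 3              # keep 3 decimal digits (illustrative choice)\n",
    "SCALE = BASE ** PRECISION\n",
    "\n",
    "def encode(x):\n",
    "    \"\"\"Encode a float as a fixed precision integer.\"\"\"\n",
    "    return int(round(x * SCALE))\n",
    "\n",
    "def decode(n):\n",
    "    \"\"\"Decode a fixed precision integer back into a float.\"\"\"\n",
    "    return n / SCALE\n",
    "\n",
    "a, b = encode(0.25), encode(0.5)    # 250 and 500\n",
    "print(decode(a + b))                # sums work directly -> 0.75\n",
    "print(decode((a * b) // SCALE))     # products must be rescaled -> 0.125"
   ]
  },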
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### C: 清理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    for model in params:\n",
    "        for param in model:\n",
    "            param *= 0\n",
    "\n",
    "    for model in models:\n",
    "        model.get()\n",
    "\n",
    "    for remote_index in range(len(compute_nodes)):\n",
    "        for param_index in range(len(params[remote_index])):\n",
    "            params[remote_index][param_index].set_(new_params[param_index])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 把它们放在一起！！\n",
    "\n",
    "现在我们知道了每个步骤，我们可以将所有步骤放到一个训练循环中！"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(epoch):\n",
    "    for data_index in range(len(remote_dataset[0])-1):\n",
    "        # update remote models\n",
    "        for remote_index in range(len(compute_nodes)):\n",
    "            data, target = remote_dataset[remote_index][data_index]\n",
    "            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])\n",
    "\n",
    "        # encrypted aggregation\n",
    "        new_params = list()\n",
    "        for param_i in range(len(params[0])):\n",
    "            spdz_params = list()\n",
    "            for remote_index in range(len(compute_nodes)):\n",
    "                spdz_params.append(params[remote_index][param_i].copy().fix_precision().share(bob, alice, crypto_provider=james).get())\n",
    "\n",
    "            new_param = (spdz_params[0] + spdz_params[1]).get().float_precision()/2\n",
    "            new_params.append(new_param)\n",
    "\n",
    "        # cleanup\n",
    "        with torch.no_grad():\n",
    "            for model in params:\n",
    "                for param in model:\n",
    "                    param *= 0\n",
    "\n",
    "            for model in models:\n",
    "                model.get()\n",
    "\n",
    "            for remote_index in range(len(compute_nodes)):\n",
    "                for param_index in range(len(params[remote_index])):\n",
    "                    params[remote_index][param_index].set_(new_params[param_index])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test():\n",
    "    models[0].eval()\n",
    "    test_loss = 0\n",
    "    for data, target in test_loader:\n",
    "        output = models[0](data)\n",
    "        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # sum up batch loss\n",
    "        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
    "        \n",
    "    test_loss /= len(test_loader.dataset)\n",
    "    print('Test set: Average loss: {:.4f}\\n'.format(test_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "t = time.time()\n",
    "\n",
    "for epoch in range(args.epochs):\n",
    "    print(f\"Epoch {epoch + 1}\")\n",
    "    train(epoch)\n",
    "    test()\n",
    "\n",
    "    \n",
    "total_time = time.time() - t\n",
    "print('Total', round(total_time, 2), 's')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 恭喜!!! 是时候加入社区了!\n",
    "\n",
    "祝贺您完成本笔记本教程！ 如果您喜欢此方法，并希望加入保护隐私、去中心化AI和AI供应链（数据）所有权的运动，则可以通过以下方式做到这一点！\n",
    "\n",
    "### 给 PySyft 加星\n",
    "\n",
    "帮助我们的社区的最简单方法是仅通过给GitHub存储库加注星标！ 这有助于提高人们对我们正在构建的出色工具的认识。\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "\n",
    "### 加入我们的 Slack!\n",
    "\n",
    "保持最新进展的最佳方法是加入我们的社区！ 您可以通过填写以下表格来做到这一点[http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### 加入代码项目!\n",
    "\n",
    "对我们的社区做出贡献的最好方法是成为代码贡献者！ 您随时可以转到PySyft GitHub的Issue页面并过滤“projects”。这将向您显示所有概述，选择您可以加入的项目！如果您不想加入项目，但是想做一些编码，则还可以通过搜索标记为“good first issue”的GitHub问题来寻找更多的“一次性”微型项目。\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### 捐赠\n",
    "\n",
    "如果您没有时间为我们的代码库做贡献，但仍想提供支持，那么您也可以成为Open Collective的支持者。所有捐款都将用于我们的网络托管和其他社区支出，例如黑客马拉松和聚会！\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
